NNX: fix Linen-parity gaps + unit tests #4255

Merged
copybara-service[bot] merged 1 commit into main from fix/nnx-linen-parity-gaps on Jun 27, 2026

@ecnal-cienet ecnal-cienet commented Jun 24, 2026

Collaborator

Description

With NNX (pure_nnx=True), several train / loss / decoder / metrics / GRPO paths either worked only on Linen or silently diverged from it. This PR closes those gaps and adds correctness unit tests. The fixes apply on main independently; PR #3526 (flip defaults to NNX) makes them the live default path, and the unit tests pin the behavior either way.

Fixes

| # | Issue |
|---|---|
| 1 | skip_step_on_spikes: silent no-op on NNX (apply_gradients didn't forward loss/grad_norm). Now forwarded + metric surfaced. |
| 2 | loss_fn: NNX checked vocab-tiling before indexer warm-up; reordered to match Linen. |
| 3 | NNX decoder logits guards used self.model_mode instead of the call-arg model_mode. |
| 4 | routed_bias: updates silently dropped on NNX (Linen "intermediates" prefix absent on NNX dict). Now matched by suffix. |
| 5 | record_internal_nn_metrics: KeyError on NNX. Now NNX-aware via suffix collection. |
| 6 | qwix: crashed under pure_nnx with the bridged decoder; bridge now skips qwix's non-Variable attrs + a config guard rejects bridged-decoder+qwix. |
| 7 | maxengine.set_engine_vars_from_base_engine: AttributeError on NNX; now uses get_kv_cache_annotations_nnx. |
| 8 | GRPO gradient_accumulation_steps>1: NotImplementedError on NNX. Implemented; also fixed the GA loss metric (sum/GA, not sum/total_weights). |
| 9 | GRPO scan_layers=False: NotImplementedError on NNX. Guard removed (NNX policy already matches the inference layout). |
| 10 | GRPO optimizer_memory_host_offload: ignored on NNX; now moves opt state to device before the update. |
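Items 4 and 5 both rely on matching entries by path suffix, so that a lookup does not depend on whether the tree carries the Linen "intermediates" prefix. The following is a minimal sketch of that idea on plain nested dicts, not the code from this PR; the helper names (`iter_leaves`, `collect_by_suffix`) and the key names in the toy trees are hypothetical.

```python
from typing import Any, Dict, Iterator, Tuple

Path = Tuple[str, ...]


def iter_leaves(tree: Any, path: Path = ()) -> Iterator[Tuple[Path, Any]]:
  """Yield (path, leaf) for every non-dict leaf of a nested dict."""
  if isinstance(tree, dict):
    for key, value in tree.items():
      yield from iter_leaves(value, path + (key,))
  else:
    yield path, tree


def collect_by_suffix(tree: Any, suffix: Path) -> Dict[Path, Any]:
  """Return {path: leaf} for every leaf whose path ends with `suffix`."""
  if not suffix:
    raise ValueError("suffix must contain at least one key")
  n = len(suffix)
  return {p: v for p, v in iter_leaves(tree) if p[-n:] == suffix}


# Hypothetical layouts: the Linen-style tree has an "intermediates" prefix,
# the NNX-style tree does not. Key names are made up for illustration.
linen_like = {"intermediates": {"decoder": {"layers_0": {"moe": {"bias_update": 1.0}}}}}
nnx_like = {"decoder": {"layers_0": {"moe": {"bias_update": 1.0}}}}

suffix = ("moe", "bias_update")
linen_hits = collect_by_suffix(linen_like, suffix)
nnx_hits = collect_by_suffix(nnx_like, suffix)
```

Both calls find the same leaf, whereas a lookup hard-coded to start at "intermediates" would fail on the second tree.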

Also re-declared the legacy GRPO config fields (inference_replicas / inference_devices_per_replica /
inference_rollouts / use_pathways_reshard) in types.py — they were dropped from the schema so
grpo.yml couldn't load (pre-existing on main).

Tests

tests/unit/{train_nnx_test,grpo_nnx_test,maxengine_nnx_test,nnx_quant_guard_test}.py — 27 pass on CPU:

PYTHONPATH=src JAX_PLATFORMS=cpu python3 -m pytest \
  tests/unit/grpo_nnx_test.py tests/unit/train_nnx_test.py \
  tests/unit/maxengine_nnx_test.py tests/unit/nnx_quant_guard_test.py -q

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the fix/nnx-linen-parity-gaps branch from 0390217 to aa18ab3 Compare June 24, 2026 14:55
codecov Bot commented Jun 24, 2026


Codecov Report

❌ Patch coverage is 59.37500% with 13 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/trainers/pre_train/train.py | 42.10% | 8 Missing and 3 partials ⚠️ |
| src/maxtext/common/metric_logger.py | 66.66% | 1 Missing and 1 partial ⚠️ |

@ecnal-cienet ecnal-cienet force-pushed the fix/nnx-linen-parity-gaps branch from aa18ab3 to bdf087f Compare June 24, 2026 16:03
@ecnal-cienet ecnal-cienet changed the title from "Fix/nnx linen parity gaps" to "NNX: fix Linen-parity gaps on the default path + unit tests" Jun 24, 2026
@ecnal-cienet ecnal-cienet force-pushed the fix/nnx-linen-parity-gaps branch 3 times, most recently from 6c939f7 to 8113532 Compare June 25, 2026 21:33
@ecnal-cienet ecnal-cienet marked this pull request as ready for review June 25, 2026 21:35
@ecnal-cienet ecnal-cienet force-pushed the fix/nnx-linen-parity-gaps branch 3 times, most recently from 544f4bd to 35ca6a0 Compare June 25, 2026 23:09
@ecnal-cienet ecnal-cienet changed the title from "NNX: fix Linen-parity gaps on the default path + unit tests" to "NNX: fix Linen-parity gaps + unit tests" Jun 25, 2026
Comment thread src/maxtext/trainers/pre_train/train.py
Comment thread src/maxtext/trainers/pre_train/train.py
Comment thread src/maxtext/experimental/rl/grpo_trainer.py Outdated
@ecnal-cienet ecnal-cienet force-pushed the fix/nnx-linen-parity-gaps branch from 35ca6a0 to 5473916 Compare June 26, 2026 21:05
With pure_nnx/enable_nnx/pure_nnx_decoder defaulting to True, several
train/loss/decoder/metrics/GRPO paths diverged from Linen. Fixes:

- skip_step_on_spikes: forward loss/grad_norm through apply_gradients to the
  optax skip-step optimizer; read is_skipped back off the NNX optimizer.
- loss_fn: check the indexer dense-warmup before num_vocab_tiling (Linen order).
- decoder logits guards: use the model_mode call-arg, not self.model_mode.
- routed_bias read: dispatch the Linen intermediates path vs an NNX suffix match.
- record_activation_metrics: collect by path suffix so it works for Linen and
  NNX, scanned and unscanned (also fixes a pre-existing Linen KeyError).
- nnx_attrs_to_linen_vars: skip non-Variable attrs (qwix bookkeeping) instead of raising.
- config: error when qwix quant can't reach a bridged Linen decoder under pure_nnx.
- maxengine.set_engine_vars_from_base_engine: skip the quant copy and use the NNX
  kv-cache annotations on the NNX path.
- GRPO _train_step_nnx: gradient-accumulation scan loop; fix the GA loss metric.
- GRPO pathways reshard: drop the scan_layers=False NotImplementedError.
- GRPO host-offload: move optimizer state to device before the in-place update.
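
The "decoder logits guards" item describes a stale-attribute bug: a check reads the mode stored on the module rather than the mode passed to the call. A toy sketch of that pattern follows; the class, the method names, and the guard condition are hypothetical and stand in for whatever the real decoder checks.

```python
TRAIN = "train"
AUTOREGRESSIVE = "autoregressive"


class ToyDecoder:
  """Toy module that remembers the mode it was constructed with."""

  def __init__(self, model_mode: str):
    self.model_mode = model_mode

  def guard_from_attr(self, model_mode: str) -> bool:
    # Pre-fix pattern: the call argument is ignored, so the guard keeps
    # answering for the construction-time mode.
    return self.model_mode == TRAIN

  def guard_from_arg(self, model_mode: str) -> bool:
    # Post-fix pattern: the guard follows the mode of this particular call.
    return model_mode == TRAIN


decoder = ToyDecoder(model_mode=TRAIN)
stale = decoder.guard_from_attr(AUTOREGRESSIVE)
fixed = decoder.guard_from_arg(AUTOREGRESSIVE)
```

Here `stale` is True even though the call asked for the autoregressive mode, while `fixed` is False.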

Tests: train_nnx_test, grpo_nnx_test, maxengine_nnx_test, nnx_quant_guard_test.
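
The GA loss metric fix (sum/GA, not sum/total_weights) can be illustrated with plain arithmetic. The sketch below assumes each micro-batch loss is already a weight-normalized mean before it is accumulated; that assumption, the function name `accumulate_loss`, and the toy data are not taken from the PR.

```python
from typing import List, Sequence, Tuple

MicroBatch = Tuple[Sequence[float], Sequence[float]]  # (token losses, weights)


def accumulate_loss(microbatches: List[MicroBatch]) -> Tuple[float, float]:
  """Return (sum / GA, sum / total_weights) over per-micro-batch mean losses."""
  loss_sum = 0.0
  total_weights = 0.0
  for token_losses, weights in microbatches:
    w = sum(weights)
    # Assumed: the per-micro-batch loss is normalized by its own weights.
    micro_loss = sum(l * wt for l, wt in zip(token_losses, weights)) / w
    loss_sum += micro_loss
    total_weights += w
  ga_steps = len(microbatches)
  return loss_sum / ga_steps, loss_sum / total_weights


# Two micro-batches with mean losses 2.0 and 4.0, two tokens each.
toy = [([2.0, 2.0], [1.0, 1.0]), ([4.0, 4.0], [1.0, 1.0])]
per_step_mean, over_normalized = accumulate_loss(toy)
```

Under that assumption, dividing the accumulated sum by the number of accumulation steps recovers the average loss (3.0 here), while dividing by the total weight normalizes a second time and understates it (1.5 here).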
@ecnal-cienet ecnal-cienet force-pushed the fix/nnx-linen-parity-gaps branch from 5473916 to 9ce8edd Compare June 26, 2026 22:05
@copybara-service copybara-service Bot merged commit 87331f6 into main Jun 27, 2026
40 checks passed
@copybara-service copybara-service Bot deleted the fix/nnx-linen-parity-gaps branch June 27, 2026 00:04
3 participants